{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9d39d946",
   "metadata": {},
   "source": [
    "# Comparison Between TreeValue and Tianshou Batch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c6db2d4",
   "metadata": {},
   "source": [
    "In this section, we will take a look at the feature and performance of the [Tianshou Batch](https://github.com/thu-ml/tianshou) library, which is developed by Tsinghua Machine Learning Group."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "069361b0",
   "metadata": {},
   "source": [
    "Before starting the comparison, let us define some thing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06fc8d26",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "\n",
    "_TREE_DATA_1 = {'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}}\n",
    "_TREE_DATA_2 = {\n",
    "    'a': torch.randn(2, 3), \n",
    "    'x': {\n",
    "        'c': torch.randn(3, 4)\n",
    "    },\n",
    "}\n",
    "_TREE_DATA_3 = {\n",
    "    'obs': torch.randn(4, 84, 84),\n",
    "    'action': torch.randint(0, 6, size=(1,)),\n",
    "    'reward': torch.rand(1),\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83461b25",
   "metadata": {},
   "source": [
    "## Read and Write Operation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "067b3f73",
   "metadata": {},
   "source": [
    "Reading and writing are the two most common operations in the tree data structure based on the data model (TreeValue and Tianshou Batch both belong to this type), so this section will compare the reading and writing performance of these two libraries."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d09a5b7",
   "metadata": {},
   "source": [
    "### TreeValue's Get and Set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9519c4bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from treevalue import FastTreeValue\n",
    "\n",
    "t = FastTreeValue(_TREE_DATA_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11c37677",
   "metadata": {},
   "outputs": [],
   "source": [
    "t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd70b0b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "t.a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c18197bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit t.a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd52f867",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_value = torch.randn(2, 3)\n",
    "t.a = new_value\n",
    "\n",
    "t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbe04d1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit t.a = new_value"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48c49731",
   "metadata": {},
   "source": [
    "### Tianshou Batch's Get and Set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1bb14c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tianshou.data import Batch\n",
    "\n",
    "b = Batch(**_TREE_DATA_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb0777c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43ef8ea3",
   "metadata": {},
   "outputs": [],
   "source": [
    "b.a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b785ab72",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit b.a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad54dc69",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_value = torch.randn(2, 3)\n",
    "b.a = new_value\n",
    "\n",
    "b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29b1d0bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit b.a = new_value"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b61ad1d0",
   "metadata": {},
   "source": [
    "## Initialization"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d70f0d54",
   "metadata": {},
   "source": [
    "### TreeValue's Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d32a679b",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit FastTreeValue(_TREE_DATA_1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24f3707b",
   "metadata": {},
   "source": [
    "### Tianshou Batch's Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac3958df",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit Batch(**_TREE_DATA_1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ab82e2d",
   "metadata": {},
   "source": [
    "## Deep Copy Operation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "210a9442",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a736274",
   "metadata": {},
   "source": [
    "### Deep Copy of TreeValue"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9bcadd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "t3 = FastTreeValue(_TREE_DATA_3)\n",
    "%timeit copy.deepcopy(t3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf8be7ea",
   "metadata": {},
   "source": [
    "### Deep Copy of Tianshou Batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91998e6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "b3 = Batch(**_TREE_DATA_3)\n",
    "%timeit copy.deepcopy(b3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "223162fb",
   "metadata": {},
   "source": [
    "## Stack, Concat and Split Operation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "85fa4a73",
   "metadata": {},
   "source": [
    "### Performance of TreeValue"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0c2b697",
   "metadata": {},
   "outputs": [],
   "source": [
    "trees = [FastTreeValue(_TREE_DATA_2) for _ in range(8)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "017ea5a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "t_stack = FastTreeValue.func(subside=True)(torch.stack)\n",
    "\n",
    "t_stack(trees)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8b3f415",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit t_stack(trees)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94b56771",
   "metadata": {},
   "outputs": [],
   "source": [
    "t_cat = FastTreeValue.func(subside=True)(torch.cat)\n",
    "\n",
    "t_cat(trees)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e9c06a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit t_cat(trees)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3ab5c8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "t_split = FastTreeValue.func(rise=True)(torch.split)\n",
    "tree = FastTreeValue({\n",
    "    'obs': torch.randn(8, 4, 84, 84),\n",
    "    'action': torch.randint(0, 6, size=(8, 1,)),\n",
    "    'reward': torch.rand(8, 1),\n",
    "})\n",
    "\n",
    "%timeit t_split(tree, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31c3ec0b",
   "metadata": {},
   "source": [
    "### Performance of Tianshou Batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ead828a",
   "metadata": {},
   "outputs": [],
   "source": [
    "batches = [Batch(**_TREE_DATA_2) for _ in range(8)]\n",
    "\n",
    "Batch.stack(batches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec9037a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit Batch.stack(batches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb8ab77e",
   "metadata": {},
   "outputs": [],
   "source": [
    "Batch.cat(batches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18dfb045",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit Batch.cat(batches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6688e51",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch = Batch({\n",
    "    'obs': torch.randn(8, 4, 84, 84),\n",
    "    'action': torch.randint(0, 6, size=(8, 1,)),\n",
    "    'reward': torch.rand(8, 1)}\n",
    ")\n",
    "\n",
    "%timeit list(Batch.split(batch, 1, shuffle=False, merge_last=True))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
